{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Retiarii Example - One-shot NAS"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example will show Retiarii's ability to **express** and **explore** the model space for Neural Architecture Search and Hyper-Parameter Tuning in a simple way. The video demo is in [YouTube](https://youtu.be/3nEx9GMHYEk) and [Bilibili](https://www.bilibili.com/video/BV1c54y1V7vx/).\n",
    "\n",
    "Let's start the journey with Retiarii!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Express the Model Space"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Model space is defined by users to express a set of models that they want to explore, which contains potentially good-performing models. In Retiarii framework, a model space is defined with two parts: a base model and possible mutations on the base model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1.1: Define the Base Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Defining a base model is almost the same as defining a PyTorch (or TensorFlow) model. Usually, you only need to replace the code ``import torch.nn as nn`` with ``import nni.retiarii.nn.pytorch as nn`` to use NNI wrapped PyTorch modules. Below is a very simple example of defining a base model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "import nni.retiarii.nn.pytorch as nn\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 6, 3, padding=1)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 3, padding=1)\n",
    "        self.conv3 = nn.Conv2d(16, 16, 1)\n",
    "\n",
    "        self.bn = nn.BatchNorm2d(16)\n",
    "\n",
    "        self.gap = nn.AdaptiveAvgPool2d(4)\n",
    "        self.fc1 = nn.Linear(16 * 4 * 4, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        bs = x.size(0)\n",
    "\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x0 = F.relu(self.conv2(x))\n",
    "        x1 = F.relu(self.conv3(x0))\n",
    "\n",
    "        x1 += x0\n",
    "        x = self.pool(self.bn(x1))\n",
    "\n",
    "        x = self.gap(x).view(bs, -1)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "    \n",
    "model = Net()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1.2: Define the Model Mutations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A base model is only one concrete model, not a model space. NNI provides APIs and primitives for users to express how the base model can be mutated, i.e., a model space that includes many models. The following will use inline Mutation APIs ``LayerChoice`` to choose a layer from candidate operations and use ``InputChoice`` to try out skip connection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "import nni.retiarii.nn.pytorch as nn\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        # self.conv1 = nn.Conv2d(3, 6, 3, padding=1)\n",
    "        self.conv1 = nn.LayerChoice([nn.Conv2d(3, 6, 3, padding=1), nn.Conv2d(3, 6, 5, padding=2)])\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        # self.conv2 = nn.Conv2d(6, 16, 3, padding=1)\n",
    "        self.conv2 = nn.LayerChoice([nn.Conv2d(6, 16, 3, padding=1), nn.Conv2d(6, 16, 5, padding=2)])\n",
    "        self.conv3 = nn.Conv2d(16, 16, 1)\n",
    "\n",
    "        self.skipconnect = nn.InputChoice(n_candidates=2)\n",
    "        self.bn = nn.BatchNorm2d(16)\n",
    "\n",
    "        self.gap = nn.AdaptiveAvgPool2d(4)\n",
    "        self.fc1 = nn.Linear(16 * 4 * 4, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        bs = x.size(0)\n",
    "\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x0 = F.relu(self.conv2(x))\n",
    "        x1 = F.relu(self.conv3(x0))\n",
    "\n",
    "        x1 = self.skipconnect([x1, x1+x0])\n",
    "        x = self.pool(self.bn(x1))\n",
    "\n",
    "        x = self.gap(x).view(bs, -1)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "    \n",
    "model = Net()"
   ]
  },
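  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough check on the size of this model space, the sketch below enumerates every combination of the three choices in `Net` with plain Python. The dictionary keys and the option strings are illustrative labels introduced here, not NNI identifiers, and the enumeration assumes that each `LayerChoice` and the `InputChoice` selects exactly one candidate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product\n",
    "\n",
    "# Illustrative labels for the three choices in Net (not NNI identifiers).\n",
    "choices = {\n",
    "    'conv1': ['3x3 conv', '5x5 conv'],\n",
    "    'conv2': ['3x3 conv', '5x5 conv'],\n",
    "    'skipconnect': ['x1', 'x1 + x0'],\n",
    "}\n",
    "\n",
    "# Assumption: every candidate model picks exactly one option per choice.\n",
    "candidates = [dict(zip(choices, combo)) for combo in product(*choices.values())]\n",
    "\n",
    "print(len(candidates), 'candidate models')  # 2 * 2 * 2 = 8\n",
    "for candidate in candidates:\n",
    "    print(candidate)"
   ]
  },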
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Explore the Model Space"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With a defined model space, users can explore the space in two ways. One is the multi-trial NAS method, which searchs by evaluating each sampled model independently. The other is using one-shot weight-sharing based search, which consumes much less computational resource compared to the first one. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this part, we focus on this **one-shot** approach. The principle of the One-shot approach is combining all the models in a model space into one big model (usually called super-model or super-graph). It takes charge of both search, training and testing, by training and evaluating this big model.\n",
    "\n",
    "Retiarii has supported some classic one-shot trainers, like DARTS trainer, ENAS trainer, ProxylessNAS trainer, Single-path trainer, and users can customize a new one-shot trainer according to the APIs provided by Retiarii conveniently.\n",
    "\n",
    "Here, we show an example to use DARTS trainer manually."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "[2021-06-07 11:12:22] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [1/391]  acc1 0.093750 (0.093750)  loss 2.286068 (2.286068)\n",
      "[2021-06-07 11:12:22] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [11/391]  acc1 0.093750 (0.089489)  loss 2.328799 (2.309416)\n",
      "[2021-06-07 11:12:23] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [21/391]  acc1 0.093750 (0.092262)  loss 2.302527 (2.309082)\n",
      "[2021-06-07 11:12:23] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [31/391]  acc1 0.109375 (0.099294)  loss 2.294730 (2.304962)\n",
      "[2021-06-07 11:12:23] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [41/391]  acc1 0.203125 (0.103277)  loss 2.284227 (2.302716)\n",
      "[2021-06-07 11:12:23] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [51/391]  acc1 0.078125 (0.106618)  loss 2.308704 (2.300639)\n",
      "[2021-06-07 11:12:23] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [61/391]  acc1 0.203125 (0.110143)  loss 2.258595 (2.298042)\n",
      "[2021-06-07 11:12:23] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [71/391]  acc1 0.078125 (0.112896)  loss 2.276706 (2.294709)\n",
      "[2021-06-07 11:12:24] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [81/391]  acc1 0.078125 (0.116898)  loss 2.309119 (2.292235)\n",
      "[2021-06-07 11:12:24] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [91/391]  acc1 0.093750 (0.118304)  loss 2.263757 (2.289659)\n",
      "[2021-06-07 11:12:24] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [101/391]  acc1 0.109375 (0.119431)  loss 2.260739 (2.287132)\n",
      "[2021-06-07 11:12:24] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [111/391]  acc1 0.109375 (0.121481)  loss 2.279930 (2.284314)\n",
      "[2021-06-07 11:12:24] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [121/391]  acc1 0.046875 (0.122934)  loss 2.270205 (2.281701)\n",
      "[2021-06-07 11:12:25] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [131/391]  acc1 0.156250 (0.125477)  loss 2.270163 (2.278612)\n",
      "[2021-06-07 11:12:25] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [141/391]  acc1 0.171875 (0.126551)  loss 2.233467 (2.276326)\n",
      "[2021-06-07 11:12:25] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [151/391]  acc1 0.109375 (0.127897)  loss 2.264694 (2.274296)\n",
      "[2021-06-07 11:12:25] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [161/391]  acc1 0.250000 (0.132279)  loss 2.259590 (2.271723)\n",
      "[2021-06-07 11:12:25] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [171/391]  acc1 0.093750 (0.134868)  loss 2.240986 (2.269037)\n",
      "[2021-06-07 11:12:25] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [181/391]  acc1 0.218750 (0.137690)  loss 2.218153 (2.266567)\n",
      "[2021-06-07 11:12:25] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [191/391]  acc1 0.078125 (0.140134)  loss 2.260816 (2.264373)\n",
      "[2021-06-07 11:12:26] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [201/391]  acc1 0.156250 (0.144123)  loss 2.191213 (2.261285)\n",
      "[2021-06-07 11:12:26] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [211/391]  acc1 0.125000 (0.146919)  loss 2.245425 (2.258747)\n",
      "[2021-06-07 11:12:26] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [221/391]  acc1 0.218750 (0.150028)  loss 2.216708 (2.255553)\n",
      "[2021-06-07 11:12:26] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [231/391]  acc1 0.250000 (0.153003)  loss 2.195549 (2.252894)\n",
      "[2021-06-07 11:12:26] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [241/391]  acc1 0.234375 (0.155666)  loss 2.169693 (2.249465)\n",
      "[2021-06-07 11:12:26] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [251/391]  acc1 0.218750 (0.158989)  loss 2.174878 (2.246355)\n",
      "[2021-06-07 11:12:27] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [261/391]  acc1 0.312500 (0.162775)  loss 2.117693 (2.243113)\n",
      "[2021-06-07 11:12:27] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [271/391]  acc1 0.265625 (0.166686)  loss 2.136203 (2.239288)\n",
      "[2021-06-07 11:12:27] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [281/391]  acc1 0.234375 (0.169095)  loss 2.213463 (2.236377)\n",
      "[2021-06-07 11:12:27] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [291/391]  acc1 0.218750 (0.171338)  loss 2.114096 (2.232892)\n",
      "[2021-06-07 11:12:27] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [301/391]  acc1 0.203125 (0.173432)  loss 2.134074 (2.229637)\n",
      "[2021-06-07 11:12:28] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [311/391]  acc1 0.265625 (0.175291)  loss 2.041354 (2.225920)\n",
      "[2021-06-07 11:12:28] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [321/391]  acc1 0.250000 (0.176840)  loss 2.081122 (2.222280)\n",
      "[2021-06-07 11:12:28] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [331/391]  acc1 0.140625 (0.178578)  loss 2.124206 (2.219168)\n",
      "[2021-06-07 11:12:28] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [341/391]  acc1 0.250000 (0.180169)  loss 2.077291 (2.215540)\n",
      "[2021-06-07 11:12:28] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [351/391]  acc1 0.250000 (0.182381)  loss 2.077531 (2.211650)\n",
      "[2021-06-07 11:12:28] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [361/391]  acc1 0.312500 (0.185033)  loss 2.016619 (2.207455)\n",
      "[2021-06-07 11:12:29] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [371/391]  acc1 0.250000 (0.187163)  loss 2.139604 (2.202785)\n",
      "[2021-06-07 11:12:29] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [381/391]  acc1 0.281250 (0.189099)  loss 2.033739 (2.198564)\n",
      "[2021-06-07 11:12:29] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [1/2] Step [391/391]  acc1 0.275000 (0.190441)  loss 1.988353 (2.194509)\n",
      "[2021-06-07 11:12:29] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [1/391]  acc1 0.296875 (0.296875)  loss 2.083627 (2.083627)\n",
      "[2021-06-07 11:12:30] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [11/391]  acc1 0.265625 (0.251420)  loss 2.042856 (2.050898)\n",
      "[2021-06-07 11:12:30] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [21/391]  acc1 0.234375 (0.273065)  loss 2.005307 (2.021047)\n",
      "[2021-06-07 11:12:30] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [31/391]  acc1 0.375000 (0.269657)  loss 1.934093 (2.014375)\n",
      "[2021-06-07 11:12:30] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [41/391]  acc1 0.265625 (0.277439)  loss 2.007705 (2.003260)\n",
      "[2021-06-07 11:12:30] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [51/391]  acc1 0.218750 (0.278799)  loss 2.014602 (2.001039)\n",
      "[2021-06-07 11:12:31] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [61/391]  acc1 0.187500 (0.278945)  loss 2.088407 (1.995837)\n",
      "[2021-06-07 11:12:31] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [71/391]  acc1 0.343750 (0.285651)  loss 1.894479 (1.988130)\n",
      "[2021-06-07 11:12:31] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [81/391]  acc1 0.281250 (0.289159)  loss 1.869002 (1.979012)\n",
      "[2021-06-07 11:12:31] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [91/391]  acc1 0.265625 (0.291552)  loss 1.848354 (1.971483)\n",
      "[2021-06-07 11:12:31] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [101/391]  acc1 0.406250 (0.290996)  loss 1.840711 (1.964297)\n",
      "[2021-06-07 11:12:31] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [111/391]  acc1 0.390625 (0.294764)  loss 1.905811 (1.958954)\n",
      "[2021-06-07 11:12:32] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [121/391]  acc1 0.250000 (0.296617)  loss 1.935214 (1.952315)\n",
      "[2021-06-07 11:12:32] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [131/391]  acc1 0.281250 (0.299618)  loss 1.901846 (1.944634)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-06-07 11:12:32] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [141/391]  acc1 0.312500 (0.302970)  loss 1.854658 (1.939751)\n",
      "[2021-06-07 11:12:32] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [151/391]  acc1 0.218750 (0.305257)  loss 1.927818 (1.934704)\n",
      "[2021-06-07 11:12:32] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [161/391]  acc1 0.343750 (0.307648)  loss 1.820810 (1.927533)\n",
      "[2021-06-07 11:12:33] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [171/391]  acc1 0.312500 (0.307383)  loss 1.800313 (1.924665)\n",
      "[2021-06-07 11:12:33] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [181/391]  acc1 0.484375 (0.307925)  loss 1.637479 (1.920402)\n",
      "[2021-06-07 11:12:33] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [191/391]  acc1 0.359375 (0.306692)  loss 1.732374 (1.917680)\n",
      "[2021-06-07 11:12:33] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [201/391]  acc1 0.406250 (0.309624)  loss 1.870701 (1.911484)\n",
      "[2021-06-07 11:12:33] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [211/391]  acc1 0.328125 (0.311982)  loss 1.785704 (1.905039)\n",
      "[2021-06-07 11:12:33] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [221/391]  acc1 0.265625 (0.312712)  loss 1.738683 (1.901547)\n",
      "[2021-06-07 11:12:33] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [231/391]  acc1 0.359375 (0.315409)  loss 1.827117 (1.894860)\n",
      "[2021-06-07 11:12:34] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [241/391]  acc1 0.375000 (0.317881)  loss 1.717454 (1.888916)\n",
      "[2021-06-07 11:12:34] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [251/391]  acc1 0.328125 (0.318663)  loss 1.873310 (1.886883)\n",
      "[2021-06-07 11:12:34] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [261/391]  acc1 0.390625 (0.320163)  loss 1.657088 (1.881767)\n",
      "[2021-06-07 11:12:34] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [271/391]  acc1 0.421875 (0.321264)  loss 1.710897 (1.877521)\n",
      "[2021-06-07 11:12:34] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [281/391]  acc1 0.421875 (0.321230)  loss 1.760745 (1.875136)\n",
      "[2021-06-07 11:12:34] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [291/391]  acc1 0.375000 (0.321413)  loss 1.669255 (1.872129)\n",
      "[2021-06-07 11:12:34] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [301/391]  acc1 0.328125 (0.322051)  loss 1.728873 (1.868047)\n",
      "[2021-06-07 11:12:35] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [311/391]  acc1 0.375000 (0.323000)  loss 1.754761 (1.864783)\n",
      "[2021-06-07 11:12:35] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [321/391]  acc1 0.437500 (0.324864)  loss 1.666240 (1.859164)\n",
      "[2021-06-07 11:12:35] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [331/391]  acc1 0.421875 (0.325954)  loss 1.661471 (1.856318)\n",
      "[2021-06-07 11:12:35] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [341/391]  acc1 0.328125 (0.326475)  loss 1.737106 (1.853075)\n",
      "[2021-06-07 11:12:35] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [351/391]  acc1 0.343750 (0.327724)  loss 1.789253 (1.849491)\n",
      "[2021-06-07 11:12:36] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [361/391]  acc1 0.250000 (0.328558)  loss 1.773805 (1.846033)\n",
      "[2021-06-07 11:12:36] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [371/391]  acc1 0.312500 (0.329094)  loss 1.901358 (1.844091)\n",
      "[2021-06-07 11:12:36] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [381/391]  acc1 0.250000 (0.330011)  loss 1.863921 (1.841390)\n",
      "[2021-06-07 11:12:36] INFO (nni.retiarii.oneshot.pytorch.darts/MainThread) Epoch [2/2] Step [391/391]  acc1 0.325000 (0.331514)  loss 1.729926 (1.837162)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from utils import accuracy\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import CIFAR10\n",
    "from nni.retiarii.oneshot.pytorch import DartsTrainer\n",
    "\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n",
    "\n",
    "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "train_dataset = CIFAR10(root=\"./data\", train=True, download=True, transform=transform)\n",
    "\n",
    "trainer = DartsTrainer(\n",
    "    model=model,\n",
    "    loss=criterion,\n",
    "    metrics=lambda output, target: accuracy(output, target),\n",
    "    optimizer=optimizer,\n",
    "    num_epochs=2,\n",
    "    dataset=train_dataset,\n",
    "    batch_size=64,\n",
    "    log_frequency=10\n",
    "    )\n",
    "\n",
    "trainer.fit()"
   ]
  },
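  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The DARTS trainer used above belongs to the family of differentiable one-shot methods. The following is a minimal, framework-free sketch of the idea, assuming scalar stand-ins for layer outputs: each choice gets one learnable architecture weight per candidate, the super-model outputs the softmax-weighted sum of all candidates, and the candidate with the largest weight is kept at export time. The function names and the numbers are hypothetical and are not taken from NNI's implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def softmax(logits):\n",
    "    # Subtract the maximum first for numerical stability.\n",
    "    shift = max(logits)\n",
    "    exps = [math.exp(v - shift) for v in logits]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "def mixed_output(candidate_outputs, alpha):\n",
    "    # Super-model forward pass for one choice: softmax-weighted sum of all candidates.\n",
    "    return sum(w * y for w, y in zip(softmax(alpha), candidate_outputs))\n",
    "\n",
    "def export_choice(alpha):\n",
    "    # After training, keep only the candidate with the largest architecture weight.\n",
    "    return max(range(len(alpha)), key=lambda i: alpha[i])\n",
    "\n",
    "# Hypothetical numbers: two candidates whose outputs are the scalars 1.0 and 3.0.\n",
    "alpha = [0.1, 0.7]\n",
    "candidate_outputs = [1.0, 3.0]\n",
    "\n",
    "print('weights:', softmax(alpha))\n",
    "print('mixed output:', mixed_output(candidate_outputs, alpha))\n",
    "print('exported index:', export_choice(alpha))"
   ]
  },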
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similarly, the optimal structure found can be exported."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final architecture: {'_mutation_1': 1, '_mutation_2': 1, '_mutation_3': [1]}\n"
     ]
    }
   ],
   "source": [
    "print('Final architecture:', trainer.export())"
   ]
  },
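  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the exported dictionary easier to read, the sketch below maps each index back to the candidate it selects. It assumes that the auto-generated labels `_mutation_1`, `_mutation_2`, and `_mutation_3` correspond to `conv1`, `conv2`, and `skipconnect` in definition order, and that indices are zero-based. Neither is confirmed by the output itself, so treat the mapping as illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dictionary printed by trainer.export() in the cell above.\n",
    "exported = {'_mutation_1': 1, '_mutation_2': 1, '_mutation_3': [1]}\n",
    "\n",
    "# Assumed mapping from auto-generated labels to the choices defined in Net.\n",
    "labels = {\n",
    "    '_mutation_1': ('conv1', ['Conv2d(3, 6, 3, padding=1)', 'Conv2d(3, 6, 5, padding=2)']),\n",
    "    '_mutation_2': ('conv2', ['Conv2d(6, 16, 3, padding=1)', 'Conv2d(6, 16, 5, padding=2)']),\n",
    "    '_mutation_3': ('skipconnect', ['x1', 'x1 + x0']),\n",
    "}\n",
    "\n",
    "def describe(exported, labels):\n",
    "    # LayerChoice entries are a single index; InputChoice entries are a list of indices.\n",
    "    readable = {}\n",
    "    for key, value in exported.items():\n",
    "        name, options = labels[key]\n",
    "        indices = value if isinstance(value, list) else [value]\n",
    "        readable[name] = [options[i] for i in indices]\n",
    "    return readable\n",
    "\n",
    "for name, chosen in describe(exported, labels).items():\n",
    "    print(name, '->', chosen)"
   ]
  },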
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
